{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d74ed9a5-72ee-40b5-8b75-bf206f9d81a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] GE_ADPT(62807,ffff9a23b010,python):2025-01-17-20:48:53.295.497 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclmdlBundleGetModelId failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclmdlBundleGetModelId\n",
      "[WARNING] GE_ADPT(62807,ffff9a23b010,python):2025-01-17-20:48:53.295.553 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclmdlBundleLoadFromMem failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclmdlBundleLoadFromMem\n",
      "[WARNING] GE_ADPT(62807,ffff9a23b010,python):2025-01-17-20:48:53.295.573 [mindspore/ccsrc/utils/dlopen_macro.h:163] DlsymAscend] Dynamically load symbol aclmdlBundleUnload failed, result = /usr/local/Ascend/ascend-toolkit/latest/lib64/libascendcl.so: undefined symbol: aclmdlBundleUnload\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:48:53.504.309 [mindspore/run_check/_check_version.py:329] MindSpore version 2.4.10 and Ascend AI software package (Ascend Data Center Solution)version 7.3 does not match, the version of software package expect one of ['7.5', '7.6']. Please refer to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:48:53.506.524 [mindspore/run_check/_check_version.py:407] Can not find the tbe operator implementation(need by mindspore-ascend). Please check whether the Environment Variable PYTHONPATH is set. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:48:56.245.61 [mindspore/run_check/_check_version.py:347] MindSpore version 2.4.10 and \"te\" wheel package version 7.3 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:48:56.275.50 [mindspore/run_check/_check_version.py:354] MindSpore version 2.4.10 and \"hccl\" wheel package version 7.3 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:48:56.283.37 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 3\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:48:57.301.31 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 2\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:48:58.320.09 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 1\n",
      "Building prefix dict from the default dictionary ...\n",
      "Loading model from cache /tmp/jieba.cache\n",
      "Loading model cost 1.277 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import os\n",
    "from tqdm.notebook import tqdm\n",
    "import numpy as np\n",
    "import mindspore as ms\n",
    "from mindspore import ops\n",
    "from mindnlp.transformers import (\n",
    "    BertGenerationTokenizer,\n",
    "    BertGenerationDecoder,\n",
    "    BertGenerationConfig,\n",
    "    CLIPModel,\n",
    "    CLIPTokenizer\n",
    ")\n",
    "from loaders.ZO_Clip_loaders import tinyimage_single_isolated_class_loader\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from mindspore import context\n",
    "import sys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "27e5e750-9f2f-4580-a91c-0860e67e9f37",
   "metadata": {},
   "outputs": [],
   "source": [
    "def tokenize_for_clip(batch_sentences, tokenizer):\n",
    "    # 使用CLIPTokenizer直接处理\n",
    "    inputs = tokenizer(\n",
    "        batch_sentences,\n",
    "        padding=True,\n",
    "        truncation=True,\n",
    "        max_length=77,\n",
    "        return_tensors=\"ms\"\n",
    "    )\n",
    "    return inputs.input_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0c2a618c-d92a-46ea-9783-a99a0f83848a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def greedysearch_generation_topk(clip_embed, bert_model, batch_size=32):\n",
    "    # 处理多个样本\n",
    "    N = clip_embed.shape[0]\n",
    "    max_len = 77\n",
    "\n",
    "    # 初始化batch的target序列\n",
    "    target_lists = [[berttokenizer.bos_token_id] for _ in range(N)]\n",
    "    top_k_lists = [[] for _ in range(N)]\n",
    "    bert_model.set_train(False)\n",
    "\n",
    "    for i in range(max_len):\n",
    "        # 批量处理target序列\n",
    "        targets = ms.Tensor(target_lists, dtype=ms.int64)\n",
    "        position_ids = ms.Tensor(np.arange(targets.shape[1])[None].repeat(N, axis=0), ms.int32)\n",
    "        attention_mask = ops.ones((N, targets.shape[1]), dtype=ms.int32)\n",
    "\n",
    "        out = bert_model(\n",
    "            input_ids=targets,\n",
    "            attention_mask=attention_mask,\n",
    "            position_ids=position_ids,\n",
    "            encoder_hidden_states=clip_embed,\n",
    "        )\n",
    "\n",
    "        pred_idxs = out.logits.argmax(axis=2)[:, -1].astype(ms.int64)\n",
    "        _, top_k = ops.topk(out.logits, dim=2, k=35)\n",
    "\n",
    "        for j in range(N):\n",
    "            target_lists[j].append(pred_idxs[j].item())\n",
    "            top_k_lists[j].append(top_k[j, -1])\n",
    "\n",
    "        if all(len(t) >= 10 for t in target_lists):\n",
    "            break\n",
    "\n",
    "    results = []\n",
    "    for i in range(N):\n",
    "        top_k_tensor = ops.concat(top_k_lists[i])\n",
    "        target_tensor = ms.Tensor(target_lists[i], dtype=ms.int64)\n",
    "        results.append((target_tensor, top_k_tensor))\n",
    "\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ccf8ce47-027d-418d-8fbd-a1e3a019e383",
   "metadata": {},
   "outputs": [],
   "source": [
    "def image_decoder(clip_model, berttokenizer, split, image_loaders=None, bert_model=None):\n",
    "    seen_labels = split[:20]\n",
    "    seen_descriptions = [f\"This is a photo of a {label}\" for label in seen_labels]\n",
    "    targets = ms.Tensor(1000 * [0] + 9000 * [1], dtype=ms.int32)\n",
    "    max_num_entities = 0\n",
    "    ood_probs_sum = []\n",
    "\n",
    "    for semantic_label in tqdm(split):\n",
    "        # print(f\"处理类别: {semantic_label}\")\n",
    "        loader = image_loaders[semantic_label]\n",
    "\n",
    "        for batch_data in loader.create_dict_iterator():\n",
    "            batch_images = batch_data[\"image\"]\n",
    "            batch_size = batch_images.shape[0]\n",
    "\n",
    "            clip_model.set_train(False)\n",
    "            clip_out = clip_model.get_image_features(pixel_values=batch_images)\n",
    "            clip_extended_embed = ops.repeat_elements(clip_out, rep=2, axis=1)\n",
    "            clip_extended_embed = ops.expand_dims(clip_extended_embed, 1)\n",
    "\n",
    "            batch_results = greedysearch_generation_topk(clip_extended_embed, bert_model)\n",
    "            del clip_extended_embed\n",
    "            del clip_out\n",
    "\n",
    "            batch_target_tokens = []\n",
    "            batch_topk_tokens = []\n",
    "\n",
    "            for target_list, topk_list in batch_results:\n",
    "                target_tokens = [berttokenizer.decode(int(pred_idx.asnumpy())) for pred_idx in target_list]\n",
    "                topk_tokens = [berttokenizer.decode(int(pred_idx.asnumpy())) for pred_idx in topk_list]\n",
    "                batch_target_tokens.append(target_tokens)\n",
    "                batch_topk_tokens.append(topk_tokens)\n",
    "\n",
    "            batch_unique_entities = []\n",
    "            for topk_tokens in batch_topk_tokens:\n",
    "                unique_entities = list(set(topk_tokens) - set(seen_labels))\n",
    "                batch_unique_entities.append(unique_entities)\n",
    "                max_num_entities = max(max_num_entities, len(unique_entities))\n",
    "\n",
    "            batch_all_desc = []\n",
    "            for unique_entities in batch_unique_entities:\n",
    "                all_desc = seen_descriptions + [f\"This is a photo of a {label}\" for label in unique_entities]\n",
    "                batch_all_desc.append(all_desc)\n",
    "\n",
    "            batch_all_desc_ids = [tokenize_for_clip(all_desc, cliptokenizer) for all_desc in batch_all_desc]\n",
    "\n",
    "            image_features = clip_model.get_image_features(pixel_values=batch_images)\n",
    "            image_features = image_features / ops.norm(image_features, dim=-1, keepdim=True)\n",
    "\n",
    "            for b_idx in range(len(batch_results)):\n",
    "                text_features = clip_model.get_text_features(input_ids=batch_all_desc_ids[b_idx])\n",
    "                text_features = text_features / ops.norm(text_features, dim=-1, keepdim=True)\n",
    "\n",
    "                similarity = 100.0 * (image_features[b_idx:b_idx + 1] @ text_features.T)\n",
    "                zeroshot_probs = ops.softmax(similarity, axis=-1).squeeze()\n",
    "\n",
    "                ood_prob_sum = float(ops.sum(zeroshot_probs[20:]).asnumpy())\n",
    "                ood_probs_sum.append(ood_prob_sum)\n",
    "\n",
    "            del batch_target_tokens\n",
    "            del batch_topk_tokens\n",
    "            del batch_unique_entities\n",
    "            del batch_all_desc\n",
    "            del image_features\n",
    "\n",
    "    auc_sum = roc_auc_score(targets.asnumpy(), np.array(ood_probs_sum))\n",
    "    print('当前split的sum_ood AUROC={}'.format(auc_sum))\n",
    "    return auc_sum\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4259ae55-5f6b-44aa-8ebb-edca0b43a18b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_args_in_notebook():\n",
    "    args = argparse.Namespace(\n",
    "        trained_path='./trained_models/COCO/'\n",
    "    )\n",
    "    return args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ceda8c1b-9202-4669-9083-62d959c5a7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:07.173.519 [mindspore/run_check/_check_version.py:329] MindSpore version 2.4.10 and Ascend AI software package (Ascend Data Center Solution)version 7.3 does not match, the version of software package expect one of ['7.5', '7.6']. Please refer to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:07.175.873 [mindspore/run_check/_check_version.py:407] Can not find the tbe operator implementation(need by mindspore-ascend). Please check whether the Environment Variable PYTHONPATH is set. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:07.176.678 [mindspore/run_check/_check_version.py:347] MindSpore version 2.4.10 and \"te\" wheel package version 7.3 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:07.177.441 [mindspore/run_check/_check_version.py:354] MindSpore version 2.4.10 and \"hccl\" wheel package version 7.3 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:07.178.188 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 3\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:08.179.920 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 2\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:09.181.915 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 1\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:10.184.315 [mindspore/run_check/_check_version.py:329] MindSpore version 2.4.10 and Ascend AI software package (Ascend Data Center Solution)version 7.3 does not match, the version of software package expect one of ['7.5', '7.6']. Please refer to the match info on: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:10.186.492 [mindspore/run_check/_check_version.py:407] Can not find the tbe operator implementation(need by mindspore-ascend). Please check whether the Environment Variable PYTHONPATH is set. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:10.187.256 [mindspore/run_check/_check_version.py:347] MindSpore version 2.4.10 and \"te\" wheel package version 7.3 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:10.188.013 [mindspore/run_check/_check_version.py:354] MindSpore version 2.4.10 and \"hccl\" wheel package version 7.3 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:10.188.765 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 3\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:11.190.514 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 2\n",
      "[WARNING] ME(62807:281473267773456,MainProcess):2025-01-17-20:49:12.192.633 [mindspore/run_check/_check_version.py:368] Please pay attention to the above warning, countdown: 1\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindnlp/transformers/tokenization_utils_base.py:1526: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted, and will be then set to `False` by default. \n",
      "  warnings.warn(\n",
      "[WARNING] DEVICE(62807,ffff9a23b010,python):2025-01-17-20:49:17.676.516 [mindspore/ccsrc/plugin/device/ascend/hal/device/ascend_vmm_adapter.h:188] CheckVmmDriverVersion] Driver version is less than 24.0.0, vmm is disabled by default, drvier_version: 23.0.rc2.2\n",
      "BertGenerationDecoder has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`.`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
      "  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
      "  - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e58b58d70a648e78c8e74be77143ba4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前split的sum_ood AUROC=0.8469840000000002\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "482624658ef04e8d838e890d3e303af6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前split的sum_ood AUROC=0.8381759999999999\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "98218503e1f540b3baa28a73aff5a0d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前split的sum_ood AUROC=0.8535600000000001\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90e3744aa85d496382785d576d772211",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前split的sum_ood AUROC=0.8344800000000001\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "414f645d99504bafb4a1230d86ee3d64",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "当前split的sum_ood AUROC=0.8711199999999999\n",
      "5个split的sum auc分数: [0.8469840000000002, 0.8381759999999999, 0.8535600000000001, 0.8344800000000001, 0.8711199999999999]\n",
      "5个split的平均分数: 0.8488640000000001 标准差: 0.012977281317749062\n"
     ]
    }
   ],
   "source": [
    "if __name__ == '__main__':\n",
    "    # 判断是否在notebook环境\n",
    "    if 'ipykernel' in sys.modules or 'IPython' in sys.modules:\n",
    "        args = get_args_in_notebook()\n",
    "        context.set_context(device_target=\"Ascend\")\n",
    "    else:\n",
    "        parser = argparse.ArgumentParser()\n",
    "        parser.add_argument('--trained_path', type=str, default='./trained_models/COCO/')\n",
    "        args = parser.parse_args()\n",
    "        context.set_context(device_target=\"Ascend\")\n",
    "\n",
    "    args.saved_model_path = args.trained_path + '/ViT-B32/'\n",
    "\n",
    "    if not os.path.exists(args.saved_model_path):\n",
    "        os.makedirs(args.saved_model_path)\n",
    "\n",
    "    # 初始化tokenizers\n",
    "    berttokenizer = BertGenerationTokenizer.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder')\n",
    "\n",
    "    # 加载CLIP模型和tokenizer\n",
    "    model_name = 'openai/clip-vit-base-patch32'\n",
    "    try:\n",
    "        clip_model = CLIPModel.from_pretrained(model_name)\n",
    "        cliptokenizer = CLIPTokenizer.from_pretrained(model_name)\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading model from mirror, trying direct download: {e}\")\n",
    "        clip_model = CLIPModel.from_pretrained(model_name)\n",
    "        cliptokenizer = CLIPTokenizer.from_pretrained(model_name)\n",
    "\n",
    "    # 初始化BERT模型\n",
    "    if (not os.path.exists(f\"{args.saved_model_path}/decoder_model\")):\n",
    "        bert_config = BertGenerationConfig.from_pretrained(\"google/bert_for_seq_generation_L-24_bbc_encoder\")\n",
    "        bert_config.is_decoder = True\n",
    "        bert_config.add_cross_attention = True\n",
    "        bert_config.return_dict = True\n",
    "        bert_model = BertGenerationDecoder.from_pretrained(\"google/bert_for_seq_generation_L-24_bbc_encoder\",\n",
    "                                                           config=bert_config)\n",
    "    else:\n",
    "        bert_model = BertGenerationDecoder.from_pretrained(f\"{args.saved_model_path}/decoder_model\")\n",
    "\n",
    "    splits, tinyimg_loaders = tinyimage_single_isolated_class_loader(dataset_dir='./data/tiny-imagenet-200/val/',\n",
    "                                                                     labels_to_ids_path='./dataloaders/tinyimagenet_labels_to_ids.txt')\n",
    "\n",
    "    sum_scores = []\n",
    "    for split in splits:\n",
    "        sum_score = image_decoder(clip_model, berttokenizer, split=split,\n",
    "                                  image_loaders=tinyimg_loaders, bert_model=bert_model)\n",
    "        sum_scores.append(sum_score)\n",
    "\n",
    "    print('5个split的sum auc分数:', sum_scores)\n",
    "    print('5个split的平均分数:', np.mean(sum_scores), '标准差:', np.std(sum_scores))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78707ed9-d072-4325-9c58-0926326628ca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
